Skip to content

Fix CuTe composition stride-divisibility check (#3177)#3181

Open
jduprat wants to merge 2 commits into
NVIDIA:mainfrom
jduprat:main
Open

Fix CuTe composition stride-divisibility check (#3177)#3181
jduprat wants to merge 2 commits into
NVIDIA:mainfrom
jduprat:main

Conversation

@jduprat

@jduprat jduprat commented Apr 21, 2026

Copy link
Copy Markdown

composition_impl() used a strict weakening of the divisibility condition:
it accepted any rhs stride smaller than the current lhs mode shape,
regardless of whether the shape was actually divisible by the stride.

For A=(4,6,8):(2,3,5), B=6:3, this lets composition(A,B) compile and
return (_2,_3):(_6,_3), but C(2)=3 != A(B(2))=7.

Replace the weak check with the stronger condition used by
pycute (layout.py:211).

Fixes #3177

jduprat added 2 commits April 30, 2026 21:21
composition_impl() used a strict weakening of the divisibility condition:
it accepted any rhs stride smaller than the current lhs mode shape,
regardless of whether the shape was actually divisible by the stride.

For A=(4,6,8):(2,3,5), B=6:3, this lets composition(A,B) compile and
return (_2,_3):(_6,_3), but C(2)=3 != A(B(2))=7.

Replace the weak check with the stronger condition used by
pycute (layout.py:211).

Fixes NVIDIA#3177
The strong divisibility check from the previous commit fixes the
wrong-answer composition from NVIDIA#3177, but rejects the paper's §3.3.3
"apparent violation" cases that produce well-defined results, e.g.

    A = (4,2,8):(3,12,97), B = 3:3   ->   3:9

After the public composition() coalesces A to (8,8):(3,97), the
strong check sees `8 % 3 != 0` and refuses to compile, even though
A(0)=0, A(3)=9, A(6)=18 is well-defined.

Add a third disjunct that accepts the safe-truncation pattern: when
B's entire image fits inside the current LHS mode, higher modes are
unreachable and cannot perturb the result. This is the §3.3.3
distinction between "apparent" and "real" divisibility violations.

Predicate now accepts iff at least one of:
  (a) (rest_stride % curr_shape) == 0   -- skip mode entirely
  (b) (curr_shape % rest_stride) == 0   -- partial traversal
  (c) (rest_shape - 1) * rest_stride < curr_shape
                                        -- safe truncation: B's image
                                           stays within the current mode

Verification matrix:

  Case                                  Pre-coalesce LHS         Decision
  ----------------------------------    ---------------------    --------
  paper §3.3.3 ok (returns 3:9)         (8,8):(3,97) o 3:3       accept
  paper §3.3.3 fail-left                (8,8):(3,97) o 4:3       reject
  paper §3.3.3 fail-right               (4,2,8):(3,15,97) o 3:3  reject
  wrong-answer bug NVIDIA#3177                (4,6,8):(2,3,5) o 6:3    reject
  CuTe test (8,8):(8,1) o 2:3           (8,8):(8,1) o 2:3        accept
  CuTe test (8,8):(8,1) o 3:3           (8,8):(8,1) o 3:3        accept
  CuTe test (8,8):(8,1) o 4:3           (8,8):(8,1) o 4:3        reject

Reference: arXiv:2603.02298 §3.3.3.
@github-actions

Copy link
Copy Markdown

This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] Discrepancy between CuTe C++ and pycute

1 participant